package mux

import (
	"log"
	"net/http"
	"strings"
	"sync"

	"gitee.com/go-wena/env"
	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
)

type (
	// HandlerFunc is the signature of a route handler; it receives the
	// per-request Context prepared by the Mux.
	HandlerFunc = func(ctx *Context)

	// gMap stores registered handlers keyed first by the (upper-cased)
	// method spec and then by the route pattern.
	gMap map[string]map[string]HandlerFunc

	// Mux collects route registrations via Handle and turns them into an
	// http.Handler via Handler.
	Mux struct {
		g  gMap       // method spec -> pattern -> handler
		p  sync.Pool  // recycles *Context values between requests
		mu sync.Mutex // protects g
	}
)

// New returns an empty Mux whose Context pool allocates a fresh *Context
// whenever no recycled one is available.
func New() *Mux {
	mux := new(Mux)
	mux.g = make(gMap)
	mux.p.New = func() interface{} {
		return new(Context)
	}
	return mux
}

// getContext takes a *Context from the pool (or allocates one) and binds it to
// the current response writer and request.
func (m *Mux) getContext(w http.ResponseWriter, q *http.Request) *Context {
	c, ok := m.p.Get().(*Context)
	if !ok {
		// A Mux that was not built with New has no pool constructor, so Get
		// returns nil; allocate instead of panicking on the type assertion.
		c = &Context{}
	}
	c.w = w
	c.q = q
	return c
}

// wrap adapts a HandlerFunc to net/http. Each request borrows a pooled
// *Context, which is returned to the pool once handler finishes (including
// when it panics, since the release is deferred).
//
// The handler must not retain the Context after it returns: the value is
// recycled for later requests.
func (m *Mux) wrap(handler HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		c := m.getContext(w, req)
		defer func() {
			// Drop the request/response references before pooling so a
			// finished request is not kept reachable by an idle Context.
			// NOTE(review): any other per-request fields of Context (defined
			// elsewhere) are not reset here — confirm they are reset on reuse.
			c.w, c.q = nil, nil
			m.p.Put(c)
		}()
		handler(c)
	}
}

// Handle registers handler for pattern under the given method spec. The spec
// is upper-cased before it is stored. Registering the same spec and pattern
// again replaces the earlier handler. Concurrent Handle calls are serialized
// by m.mu, and a Mux with a nil route table is initialized lazily.
func (m *Mux) Handle(method string, pattern string, handler HandlerFunc) {
	key := strings.ToUpper(method)

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.g == nil {
		m.g = gMap{}
	}
	if m.g[key] == nil {
		m.g[key] = map[string]HandlerFunc{}
	}
	m.g[key][pattern] = handler
}

// Handler builds a chi router serving every route registered through Handle,
// plus the profiler group, behind a common middleware stack.
//
// Middleware order matters:
//   - RequestID and RealIP run first, so the request ID and the rewritten
//     client address are already on the request when Logger creates its
//     log entry.
//   - Logger wraps Recoverer, so a recovered panic is reported through the
//     request's log entry and the 500 written by Recoverer is what gets
//     logged (chi documents that Logger must precede Recoverer).
//
// The optional middlewares are toggled with these env keys:
//   - router.middleware.request_id, router.middleware.real_ip and
//     router.middleware.logger: skipped when env.IsFalse reports the key as
//     false.
//   - router.middleware.gzip: an integer; when it is > 0, responses are
//     compressed at that level.
//
// Every registered route is printed once via the standard logger.
func (m *Mux) Handler() http.Handler {
	mux := chi.NewMux()

	if !env.IsFalse("router.middleware.request_id") {
		mux.Use(middleware.RequestID)
	}
	if !env.IsFalse("router.middleware.real_ip") {
		mux.Use(middleware.RealIP)
	}
	if !env.IsFalse("router.middleware.logger") {
		mux.Use(middleware.Logger)
	}
	mux.Use(middleware.Recoverer, middleware.CleanPath, middleware.StripSlashes)
	if gzl := env.GetInt("router.middleware.gzip"); gzl > 0 {
		mux.Use(middleware.Compress(int(gzl)))
	}

	mux.Group(m.mapped)
	mux.Group(m.profiler)

	// The walk callback never returns an error, so Walk's result carries no
	// information worth handling.
	_ = chi.Walk(mux, func(method string, route string, handler http.Handler, _ ...func(http.Handler) http.Handler) error {
		log.Printf("[mux] %-7s %s", method, route)
		return nil
	})

	return mux
}

// mapped installs every handler registered through Handle onto router.
//
// A method spec that expands to no verbs (empty or "*", see splitMethod) is
// mounted with router.Handle for all methods. Otherwise the handler is mounted
// once per listed verb, skipping any "*" entries. m.mu is held for the whole
// walk because Handle mutates m.g under the same lock.
func (m *Mux) mapped(router chi.Router) {
	m.mu.Lock()
	defer m.mu.Unlock()

	for spec, patternMap := range m.g {
		// The expansion depends only on the spec, so compute it once per spec.
		verbs := splitMethod(spec)
		for pattern, handler := range patternMap {
			// A single adapter is shared by all verbs of this route.
			h := m.wrap(handler)
			if len(verbs) == 0 {
				router.Handle(pattern, h)
				continue
			}
			for _, verb := range verbs {
				if verb != "" && verb != "*" {
					router.Method(verb, pattern, h)
				}
			}
		}
	}
}

// splitMethod breaks a method spec such as "GET POST" into its
// whitespace-separated verbs. It returns nil when the spec contains no verb or
// consists solely of the wildcard "*"; in every other case it returns the verbs
// in order.
func splitMethod(method string) []string {
	verbs := strings.Fields(method)
	switch {
	case len(verbs) == 0:
		return nil
	case len(verbs) == 1 && verbs[0] == "*":
		// strings.Fields never yields empty strings, so "*" is the only
		// single-element spec that means "no specific verb".
		return nil
	default:
		return verbs
	}
}
